/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef EXAMPLES_MATRIX_MATMUL_IBSHARE_CUSTOM_IMPL_H
#define EXAMPLES_MATRIX_MATMUL_IBSHARE_CUSTOM_IMPL_H
#include "kernel_operator.h"
#include "lib/matmul_intf.h"

// Divisor applied to the block index in CalcOffset() when both the A and B operands
// are instantiated with IBShare enabled.
constexpr int32_t MIX_RATIO = 2;  // AIC:AIV = 1:2

// Selects the compile-time MatmulConfig for a pair of IBShare flags.
//   - exactly one of ibshareA / ibshareB set -> GetIBShareNormConfig()
//   - both set, or both clear                -> GetNormalConfig()
template <const bool ibshareA, const bool ibshareB>
__aicore__ inline constexpr MatmulConfig GetCFG()
{
    // "Flags differ" is the complement of "both set or both clear".
    if constexpr (ibshareA != ibshareB) {
        return GetIBShareNormConfig();
    } else {
        return GetNormalConfig();
    }
}

// Wrapper around an AscendC::Matmul object whose A/B/C/bias operands all live in global
// memory (TPosition::GM) in ND format.
//
// Template parameters:
//   aType, bType, cType, biasType - element types of the A, B, C and bias tensors.
//   ibshareA, ibshareB            - forwarded as the IBShare flag of the A / B MatmulType
//                                   and used to pick the MatmulConfig via GetCFG().
template <typename aType, typename bType, typename cType, typename biasType, const bool ibshareA, const bool ibshareB>
class MatmulIbshareKernel {
private:
    // Matmul configuration chosen at compile time from the two IBShare flags.
    constexpr static MatmulConfig MM_CFG = GetCFG<ibshareA, ibshareB>();

    // Operand descriptors handed to AscendC::Matmul (A and B are declared non-transposed).
    using MatTypeA =
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, aType, false, LayoutMode::NONE, ibshareA>;
    using MatTypeB =
        AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, bType, false, LayoutMode::NONE, ibshareB>;
    using MatTypeC = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, cType>;
    using MatTypeBias = AscendC::MatmulType<AscendC::TPosition::GM, CubeFormat::ND, biasType>;

public:
    __aicore__ inline MatmulIbshareKernel() {}

    // Stores the tiling, binds the four global buffers and moves each tensor view to the
    // sub-block that belongs to the current block index.
    __aicore__ inline void Init(GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace,
                                const TCubeTiling& tiling);

    // Feeds the bound tensors to matmulObj and runs the whole multiplication into C.
    // hasBias selects at compile time whether the bias tensor is attached.
    template <bool hasBias = false>
    __aicore__ inline void Process(AscendC::TPipe* pipe);

    // Public matmul object.
    // NOTE(review): presumably public so the caller can register/initialise it before
    // Init()/Process() - confirm against the kernel entry point.
    AscendC::Matmul<MatTypeA, MatTypeB, MatTypeC, MatTypeBias, MM_CFG> matmulObj;

private:
    // Derives the element offsets of this block into A, B, C and bias, and sets the tail
    // shape on matmulObj when the block is not a full singleCoreM x singleCoreN tile.
    __aicore__ inline void CalcOffset(int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA,
                                      int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias);

    // Global-memory views of the operands; after Init() they start at this block's sub-block.
    AscendC::GlobalTensor<aType> aGlobal;
    AscendC::GlobalTensor<bType> bGlobal;
    AscendC::GlobalTensor<cType> cGlobal;
    AscendC::GlobalTensor<biasType> biasGlobal;

    // Copy of the tiling passed to Init().
    TCubeTiling tiling;
};

// Binds the A/B/C/bias global-memory buffers and repositions each tensor view at the
// sub-block owned by the current block index.
//
// Parameters:
//   a, b, bias, c - global-memory addresses of the operands.
//   workspace     - not referenced by this function.
//   tiling        - cube tiling; copied into the member of the same name.
//
// Side effect: CalcOffset() may call matmulObj.SetTail() for a partial (tail) block.
template <typename aType, typename bType, typename cType, typename biasType, const bool ibshareA, const bool ibshareB>
__aicore__ inline void MatmulIbshareKernel<aType, bType, cType, biasType, ibshareA, ibshareB>::Init(
    GM_ADDR a, GM_ADDR b, GM_ADDR bias, GM_ADDR c, GM_ADDR workspace, const TCubeTiling& tiling)
{
    this->tiling = tiling;

    // Element counts are formed in 64 bits: the product of two 32-bit tiling dimensions
    // can exceed INT32_MAX for large shapes, which would overflow if multiplied as int32_t.
    const uint64_t elemCountA = static_cast<uint64_t>(tiling.M) * static_cast<uint64_t>(tiling.Ka);
    const uint64_t elemCountB = static_cast<uint64_t>(tiling.Kb) * static_cast<uint64_t>(tiling.N);
    const uint64_t elemCountC = static_cast<uint64_t>(tiling.M) * static_cast<uint64_t>(tiling.N);

    aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ aType*>(a), elemCountA);
    bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ bType*>(b), elemCountB);
    cGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ cType*>(c), elemCountC);
    biasGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ biasType*>(bias), tiling.N);

    int32_t offsetA = 0;
    int32_t offsetB = 0;
    int32_t offsetC = 0;
    int32_t offsetBias = 0;
    CalcOffset(AscendC::GetBlockIdx(), tiling, offsetA, offsetB, offsetC, offsetBias);

    // Re-seat every view so that element 0 is the first element of this block's tile.
    aGlobal = aGlobal[offsetA];
    bGlobal = bGlobal[offsetB];
    cGlobal = cGlobal[offsetC];
    biasGlobal = biasGlobal[offsetBias];

    // NOTE(review): this is the last statement of the function, so returning here has the
    // same effect as falling through; the check does not prevent any of the setup above.
    // Kept as in the original - confirm whether it was meant to guard earlier work.
    if (GetSysWorkSpacePtr() == nullptr) {
        return;
    }
}

// Runs the matrix multiplication for this block.
//
// Attaches the A and B views (and the bias view when hasBias is true) to matmulObj,
// computes the full result into the C view with IterateAll(), then calls End().
//
// Parameters:
//   pipe - not referenced by this function.
template <typename aType, typename bType, typename cType, typename biasType, const bool ibshareA, const bool ibshareB>
template <bool hasBias>
__aicore__ inline void MatmulIbshareKernel<aType, bType, cType, biasType, ibshareA, ibshareB>::Process(AscendC::TPipe* pipe)
{
    (void)pipe;  // intentionally unused

    auto& mm = this->matmulObj;

    // Inputs first ...
    mm.SetTensorA(this->aGlobal);
    mm.SetTensorB(this->bGlobal);
    if constexpr (hasBias) {
        mm.SetBias(this->biasGlobal);
    }

    // ... then compute into C and finish.
    mm.IterateAll(this->cGlobal);
    mm.End();
}

// Integer ceiling division: returns the smallest q such that q * b >= a.
//
// Parameters:
//   a - dividend.
//   b - divisor; must be non-zero (division by zero is not guarded here).
//
// Computed as quotient plus one-if-remainder rather than (a + b - 1) / b, because the
// latter wraps around in uint32_t when a + b - 1 exceeds UINT32_MAX and then yields a
// result that is too small.
__aicore__ inline uint32_t Ceiling(uint32_t a, uint32_t b)
{
    return a / b + ((a % b != 0U) ? 1U : 0U);
}

// Maps a block index to the element offsets of its tile in A, B, C and bias.
//
// Blocks are numbered along M first: the index modulo the number of M-blocks selects the
// M-block, the index divided by it selects the N-block. When both ibshareA and ibshareB
// are set, the incoming index is first divided by MIX_RATIO.
//
// Outputs (element offsets):
//   offsetA    = rowStart * Ka
//   offsetB    = colStart
//   offsetC    = rowStart * N + colStart
//   offsetBias = colStart
// where rowStart = mBlock * singleCoreM and colStart = nBlock * singleCoreN.
// NOTE(review): these formulas look like row-major, non-transposed A/B/C - confirm.
//
// Side effect: if the tile is smaller than singleCoreM x singleCoreN in either dimension,
// the actual tile shape is passed to matmulObj.SetTail().
template <typename aType, typename bType, typename cType, typename biasType, const bool ibshareA, const bool ibshareB>
__aicore__ inline void MatmulIbshareKernel<aType, bType, cType, biasType, ibshareA, ibshareB>::CalcOffset(
    int32_t blockIdx, const TCubeTiling& tiling, int32_t& offsetA, int32_t& offsetB, int32_t& offsetC, int32_t& offsetBias)
{
    if constexpr (ibshareA && ibshareB) {
        blockIdx = blockIdx / MIX_RATIO;
    }

    // Split the linear block index into (M-block, N-block).
    const auto blocksAlongM = Ceiling(tiling.M, tiling.singleCoreM);
    const auto rowBlock = blockIdx % blocksAlongM;
    const auto colBlock = blockIdx / blocksAlongM;

    // First row / first column covered by this tile.
    const auto rowStart = rowBlock * tiling.singleCoreM;
    const auto colStart = colBlock * tiling.singleCoreN;

    offsetBias = colStart;
    offsetB = colStart;
    offsetA = rowStart * tiling.Ka;
    offsetC = rowStart * tiling.N + colStart;

    // Rows / columns left from the tile start to the matrix edge, capped at the tile size.
    const int32_t rowsLeft = tiling.M - rowStart;
    const int32_t colsLeft = tiling.N - colStart;
    const int32_t tailM = (rowsLeft < tiling.singleCoreM) ? rowsLeft : tiling.singleCoreM;
    const int32_t tailN = (colsLeft < tiling.singleCoreN) ? colsLeft : tiling.singleCoreN;

    // Only a partial tile needs its shape announced to the matmul object.
    const bool isFullTile = (tailM >= tiling.singleCoreM) && (tailN >= tiling.singleCoreN);
    if (!isFullTile) {
        matmulObj.SetTail(tailM, tailN);
    }
}
#endif // EXAMPLES_MATRIX_MATMUL_IBSHARE_CUSTOM_IMPL_H